{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IqM-T1RTzY6C"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "form",
        "id": "dUKJPsGDPQlS"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bBKhPThKPac5"
      },
      "source": [
        "# Fine-tuning Gemma using Unsloth\n",
        "\n",
        "Welcome to this step-by-step guide on fine-tuning the [Gemma](https://huggingface.co/google/gemma-2b) using [Unsloth](https://unsloth.ai/).\n",
        "\n",
        "\n",
        "[**Gemma**](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. Gemma models are well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop or your own cloud infrastructure, democratizing access to state of the art AI models and helping foster innovation for everyone.\n",
        "\n",
        "[**Unsloth**](https://unsloth.ai/) is a Python framework developed for finetuning large language models like Gemma 2.4x faster, using 58% less VRAM, and with no degradation in accuracy.\n",
        "\n",
        "\n",
        "In this notebook, you will learn how to finetune Gemma 2 models using **Unsloth** in a Google Colab environment. You'll install the necessary packages, finetune the model, and run a sample inference.\n",
        "\n",
        "<table align=\"left\">\n",
        "<td>\n",
        " <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Finetune_with_Unsloth.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "</td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0j8QtM7uTQNd"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Select the Colab runtime\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **T4 GPU**.\n",
        "\n",
        "### Hugging Face setup\n",
        "\n",
        "**Before you dive into the tutorial, let's get you set up with Hugging Face. This is needed to upload the finetuned model into Hugging Face Hub.**\n",
        "\n",
        "1. **Hugging Face Account:**  If you don't already have one, you can create a free Hugging Face account by clicking [here](https://huggingface.co/join).\n",
        "2. **Hugging Face Token:**  Generate a Hugging Face access (with `write` permission) token by clicking [here](https://huggingface.co/settings/tokens). You'll need this token later in the tutorial.\n",
        "\n",
        "**Once you've completed these steps, you're ready to move on to the next section where you'll set up environment variables in your Colab environment.**\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oMi-K8lVMdpE"
      },
      "source": [
        "### Configure your HF token\n",
        "\n",
        "Add your Hugging Face token to the Colab Secrets manager to securely store it.\n",
        "\n",
        "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "2. Create a new secret with the name `HF_TOKEN`.\n",
        "3. Copy/paste your token key into the Value input box of `HF_TOKEN`.\n",
        "4. Toggle the button on the left to allow notebook access to the secret."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "pYAHbvbeMhIF"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")"
      ]
    },
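    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you run this notebook outside Colab, `google.colab` cannot be imported. The next cell is a minimal sketch of one possible fallback, not part of the original setup: the helper name `resolve_hf_token` is hypothetical, and it assumes you have already exported `HF_TOKEN` in your shell."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "\n",
        "def resolve_hf_token(env_var=\"HF_TOKEN\"):\n",
        "    # Hypothetical helper: prefer Colab Secrets, fall back to the environment.\n",
        "    try:\n",
        "        from google.colab import userdata  # Only importable inside Colab.\n",
        "    except ImportError:\n",
        "        token = os.environ.get(env_var)\n",
        "    else:\n",
        "        token = userdata.get(env_var)\n",
        "    if not token:\n",
        "        raise RuntimeError(f\"{env_var} is not set.\")\n",
        "    return token\n",
        "\n",
        "\n",
        "os.environ[\"HF_TOKEN\"] = resolve_hf_token()"
      ]
    },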
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QvC30lMdTZ4p"
      },
      "source": [
        "### Install dependencies\n",
        "\n",
        "You'll need to install a few Python packages and dependencies for Unsloth.\n",
        "\n",
        "Run the following cell to install or upgrade the necessary packages:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "2eSvM9zX_2d3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting unsloth\n",
            "  Downloading unsloth-2024.11.11-py3-none-any.whl.metadata (58 kB)\n",
            "\u001b[?25l     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/58.5 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.5/58.5 kB\u001b[0m \u001b[31m12.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting unsloth_zoo>=2024.11.8 (from unsloth)\n",
            "  Downloading unsloth_zoo-2024.11.8-py3-none-any.whl.metadata (16 kB)\n",
            "Requirement already satisfied: torch>=2.4.0 in /usr/local/lib/python3.10/dist-packages (from unsloth) (2.5.1+cu121)\n",
            "Collecting xformers>=0.0.27.post2 (from unsloth)\n",
            "  Downloading xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n",
            "Collecting bitsandbytes (from unsloth)\n",
            "  Downloading bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)\n",
            "Collecting triton>=3.0.0 (from unsloth)\n",
            "  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from unsloth) (24.2)\n",
            "Collecting tyro (from unsloth)\n",
            "  Downloading tyro-0.9.2-py3-none-any.whl.metadata (9.4 kB)\n",
            "Requirement already satisfied: transformers>=4.46.1 in /usr/local/lib/python3.10/dist-packages (from unsloth) (4.46.2)\n",
            "Collecting datasets>=2.16.0 (from unsloth)\n",
            "  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)\n",
            "Requirement already satisfied: sentencepiece>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from unsloth) (0.2.0)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from unsloth) (4.66.6)\n",
            "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from unsloth) (5.9.5)\n",
            "Requirement already satisfied: wheel>=0.42.0 in /usr/local/lib/python3.10/dist-packages (from unsloth) (0.45.0)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from unsloth) (1.26.4)\n",
            "Requirement already satisfied: accelerate>=0.34.1 in /usr/local/lib/python3.10/dist-packages (from unsloth) (1.1.1)\n",
            "Collecting trl!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,>=0.7.9 (from unsloth)\n",
            "  Downloading trl-0.12.1-py3-none-any.whl.metadata (10 kB)\n",
            "Requirement already satisfied: peft!=0.11.0,>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from unsloth) (0.13.2)\n",
            "Collecting protobuf<4.0.0 (from unsloth)\n",
            "  Downloading protobuf-3.20.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (679 bytes)\n",
            "Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (from unsloth) (0.26.2)\n",
            "Collecting hf_transfer (from unsloth)\n",
            "  Downloading hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)\n",
            "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.34.1->unsloth) (6.0.2)\n",
            "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.34.1->unsloth) (0.4.5)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.0->unsloth) (3.16.1)\n",
            "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.0->unsloth) (17.0.0)\n",
            "Collecting dill<0.3.9,>=0.3.0 (from datasets>=2.16.0->unsloth)\n",
            "  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.0->unsloth) (2.2.2)\n",
            "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.0->unsloth) (2.32.3)\n",
            "Collecting xxhash (from datasets>=2.16.0->unsloth)\n",
            "  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n",
            "Collecting multiprocess<0.70.17 (from datasets>=2.16.0->unsloth)\n",
            "  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n",
            "Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets>=2.16.0->unsloth)\n",
            "  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.0->unsloth) (3.11.2)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub->unsloth) (4.12.2)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.4.0->unsloth) (3.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.4.0->unsloth) (3.1.4)\n",
            "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch>=2.4.0->unsloth) (1.13.1)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch>=2.4.0->unsloth) (1.3.0)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.46.1->unsloth) (2024.9.11)\n",
            "Requirement already satisfied: tokenizers<0.21,>=0.20 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.46.1->unsloth) (0.20.3)\n",
            "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from trl!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,>=0.7.9->unsloth) (13.9.4)\n",
            "Collecting cut_cross_entropy (from unsloth_zoo>=2024.11.8->unsloth)\n",
            "  Downloading cut_cross_entropy-24.11.4-py3-none-any.whl.metadata (9.3 kB)\n",
            "Requirement already satisfied: pillow in /usr/local/lib/python3.10/dist-packages (from unsloth_zoo>=2024.11.8->unsloth) (11.0.0)\n",
            "Requirement already satisfied: docstring-parser>=0.16 in /usr/local/lib/python3.10/dist-packages (from tyro->unsloth) (0.16)\n",
            "Collecting shtab>=1.5.6 (from tyro->unsloth)\n",
            "  Downloading shtab-1.7.1-py3-none-any.whl.metadata (7.3 kB)\n",
            "Requirement already satisfied: typeguard>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from tyro->unsloth) (4.4.1)\n",
            "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->unsloth) (2.4.3)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->unsloth) (1.3.1)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->unsloth) (24.2.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->unsloth) (1.5.0)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->unsloth) (6.1.0)\n",
            "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->unsloth) (0.2.0)\n",
            "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->unsloth) (1.17.2)\n",
            "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.0->unsloth) (4.0.3)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets>=2.16.0->unsloth) (3.4.0)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets>=2.16.0->unsloth) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets>=2.16.0->unsloth) (2.2.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets>=2.16.0->unsloth) (2024.8.30)\n",
            "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->trl!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,>=0.7.9->unsloth) (3.0.0)\n",
            "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->trl!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,>=0.7.9->unsloth) (2.18.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.4.0->unsloth) (3.0.2)\n",
            "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets>=2.16.0->unsloth) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets>=2.16.0->unsloth) (2024.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets>=2.16.0->unsloth) (2024.2)\n",
            "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->trl!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,>=0.7.9->unsloth) (0.1.2)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets>=2.16.0->unsloth) (1.16.0)\n",
            "Downloading unsloth-2024.11.11-py3-none-any.whl (167 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m167.6/167.6 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading datasets-3.1.0-py3-none-any.whl (480 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading protobuf-3.20.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m170.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.5/209.5 MB\u001b[0m \u001b[31m208.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading trl-0.12.1-py3-none-any.whl (310 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m310.9/310.9 kB\u001b[0m \u001b[31m241.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading unsloth_zoo-2024.11.8-py3-none-any.whl (59 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.1/59.1 kB\u001b[0m \u001b[31m268.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl (16.7 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m16.7/16.7 MB\u001b[0m \u001b[31m242.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl (122.4 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m122.4/122.4 MB\u001b[0m \u001b[31m203.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.6/3.6 MB\u001b[0m \u001b[31m252.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading tyro-0.9.2-py3-none-any.whl (112 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m112.1/112.1 kB\u001b[0m \u001b[31m303.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m310.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (179 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m343.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m342.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading shtab-1.7.1-py3-none-any.whl (14 kB)\n",
            "Downloading cut_cross_entropy-24.11.4-py3-none-any.whl (22 kB)\n",
            "Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m328.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: xxhash, triton, shtab, protobuf, hf_transfer, fsspec, dill, multiprocess, xformers, tyro, cut_cross_entropy, bitsandbytes, datasets, trl, unsloth_zoo, unsloth\n",
            "  Attempting uninstall: protobuf\n",
            "    Found existing installation: protobuf 4.25.5\n",
            "    Uninstalling protobuf-4.25.5:\n",
            "      Successfully uninstalled protobuf-4.25.5\n",
            "  Attempting uninstall: fsspec\n",
            "    Found existing installation: fsspec 2024.10.0\n",
            "    Uninstalling fsspec-2024.10.0:\n",
            "      Successfully uninstalled fsspec-2024.10.0\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\n",
            "grpcio-status 1.62.3 requires protobuf>=4.21.6, but you have protobuf 3.20.3 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed bitsandbytes-0.44.1 cut_cross_entropy-24.11.4 datasets-3.1.0 dill-0.3.8 fsspec-2024.9.0 hf_transfer-0.1.8 multiprocess-0.70.16 protobuf-3.20.3 shtab-1.7.1 triton-3.1.0 trl-0.12.1 tyro-0.9.2 unsloth-2024.11.11 unsloth_zoo-2024.11.8 xformers-0.0.28.post3 xxhash-3.5.0\n"
          ]
        },
        {
          "data": {
            "application/vnd.colab-display-data+json": {
              "id": "0fe89fa18e094bf4a13b8fd1ffdde736",
              "pip_warning": {
                "packages": [
                  "google"
                ]
              }
            }
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "#   Install Unsloth library\n",
        "! pip install unsloth --upgrade --no-cache-dir\n",
        "\n",
        "# Install Flash Attention 2 for softcapping support\n",
        "import torch\n",
        "if torch.cuda.get_device_capability()[0] >= 8:\n",
        "    !pip install --no-deps packaging ninja einops \"flash-attn>=2.6.3\""
      ]
    },
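    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `>= 8` check in the install cell refers to the major version of the GPU's CUDA compute capability. As an optional sanity check, the following sketch prints what your runtime reports. It uses only standard `torch.cuda` calls and is not part of the original workflow."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "    # Compute capability is reported as a (major, minor) tuple.\n",
        "    major, minor = torch.cuda.get_device_capability()\n",
        "    name = torch.cuda.get_device_name()\n",
        "    print(f\"{name}: compute capability {major}.{minor}\")\n",
        "    print(\"Meets the Flash Attention 2 check in this notebook:\", major >= 8)\n",
        "else:\n",
        "    print(\"No CUDA GPU detected; check the Colab runtime type.\")"
      ]
    },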
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r2v_X2fA0Df5"
      },
      "source": [
        "## Finetuning Gemma 2 using Unsloth library\n",
        "\n",
        "### Initializing Gemma 2 model\n",
        "\n",
        "Unsloth library supports a variety of open-source LLMs including Gemma. For this notebook, you will use Gemma 2's 2b model. You can load the Gemma 2 model in Unsloth using the `FastLanguageModel` class. To know more about the other variants of Gemma 2 model provided by Unsloth, visit the [Unsloth's supported models documentation](https://docs.unsloth.ai/get-started/all-our-models)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "QmUBVEnvCDJv"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
            "🦥 Unsloth Zoo will now patch everything to make training faster!\n",
            "==((====))==  Unsloth 2024.11.11: Fast Gemma2 patching. Transformers:4.46.2.\n",
            "   \\\\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform: Linux.\n",
            "O^O/ \\_/ \\    Torch: 2.5.1+cu121. CUDA: 7.5. CUDA Toolkit: 12.1. Triton: 3.1.0\n",
            "\\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post3. FA2 = False]\n",
            " \"-____-\"     Free Apache license: http://github.com/unslothai/unsloth\n",
            "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "82305d44082845d5a3cab38968e10ee7",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/2.22G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e13d6a3234c34c6894a9b3f08f7401a0",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "18337511819f44bc9c9a505ce2945453",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/46.4k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b474787293b147e6be5c7e898e82f5f2",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "035fd37b97c942b9964b5d868f80cd49",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "df7113d85a724baba5be3d1e045d16e8",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from unsloth import FastLanguageModel\n",
        "import torch\n",
        "\n",
        "max_seq_length = 2048\n",
        "\n",
        "model, tokenizer = FastLanguageModel.from_pretrained(\n",
        "    model_name = \"unsloth/gemma-2-2b\",\n",
        "    max_seq_length = max_seq_length,\n",
        "    dtype = None,        # None for auto detection.\n",
        "    load_in_4bit = True, # Use 4bit quantization to reduce memory usage.\n",
        ")"
      ]
    },
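    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why `load_in_4bit = True` helps on a T4, here is a back-of-the-envelope estimate of the memory taken by the weights alone. Treat it as a sketch under stated assumptions: the parameter count of roughly 2.6 billion is an approximation that this notebook does not state, the helper `weight_memory_gb` is hypothetical, and the estimate ignores activations, optimizer state, and any tensors kept in higher precision."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def weight_memory_gb(num_params, bits_per_param):\n",
        "    # Hypothetical helper: bits -> bytes -> gigabytes (10**9 bytes).\n",
        "    return num_params * bits_per_param / 8 / 1e9\n",
        "\n",
        "\n",
        "approx_params = 2.6e9  # Assumed, approximate parameter count.\n",
        "\n",
        "for label, bits in [(\"16-bit\", 16), (\"4-bit\", 4)]:\n",
        "    print(f\"{label} weights: ~{weight_memory_gb(approx_params, bits):.1f} GB\")"
      ]
    },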
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xvCgS45kZFdu"
      },
      "source": [
        "### Load a dataset\n",
        "\n",
        "For this guide, you'll use an existing dataset from Hugging Face. You can replace it with your own dataset if you prefer.\n",
        "\n",
        "The dataset chosen for this guide is [**yahma/alpaca-cleaned**](https://huggingface.co/datasets/yahma/alpaca-cleaned), which is a clean version of the original Alpaca dataset by Stanford. The Alpaca dataset is a collection of over 50,000 instructions and demonstrations that can be used to fine-tune language models to better understand and follow instructions.\n",
        "\n",
        "**Credits:** **https://huggingface.co/yahma**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "FTEOSyqzZ5wx"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "67031f495ee54b22b452e3e16f489335",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "README.md:   0%|          | 0.00/11.6k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1d85ba466f6b465d8e77dee2742addea",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "alpaca_data_cleaned.json:   0%|          | 0.00/44.3M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "41940dec1dab42a78f086efcde3c7d83",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating train split:   0%|          | 0/51760 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from datasets import load_dataset\n",
        "dataset = load_dataset(\"yahma/alpaca-cleaned\", split = \"train\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HE5WGhfdcp_W"
      },
      "source": [
        "Let's look at a few samples to understand the data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "TGFAfMpZcooc"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'output': \"The motherboard, also known as the mainboard or system board, is the central printed circuit board in a computer. It serves as the backbone or foundation for a computer, connecting all the different components such as the CPU, RAM, storage drives, expansion cards, and peripherals. The motherboard manages communication and data transfer between these components, allowing them to work together and perform their designated tasks.\\n\\nThe motherboard also includes important circuitry such as the power regulation circuit that provides power to the different components, and the clock generator which synchronizes the operation of these components. It also contains the BIOS (basic input/output system), which is a firmware that controls the boot process and provides an interface for configuring and managing the computer's hardware. Other features on a motherboard may include built-in networking, audio, and video capabilities.\\n\\nOverall, the function of a computer motherboard is to provide a platform for the integration and operation of all the various components that make up a computer, making it an essential part of any computing system.\",\n",
              " 'input': '',\n",
              " 'instruction': 'Describe the function of a computer motherboard'}"
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dataset[15]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "-BEG9f4XdA3o"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "summary": "{\n  \"name\": \"df_train\",\n  \"rows\": 51760,\n  \"fields\": [\n    {\n      \"column\": \"output\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 51498,\n        \"samples\": [\n          \"The midday sun shone high in the sky, casting golden rays upon the fields of lush green grass. In the distance, a castle stood tall and proud, its grey stone walls a testament to the power of its lord. A young peasant girl worked in the fields, her hands stained with dirt and sweat beading on her brow. Suddenly, she heard the sound of hooves approaching and turned to see a knight on horseback, his armor shining in the sun. He dismounted and approached her, offering a smile and a helping hand. And so, a friendship was forged, in the shadow of the castle.\",\n          \"Two common characteristics of mammals are:\\n1. Mammary glands: Female mammals possess mammary glands that produce milk to nourish their offspring.\\n2. Hair or fur: All mammals have some form of hair or fur on their bodies, which provides insulation and helps regulate body temperature.\",\n          \"The quadratic equation is in standard form: `ax^2 + bx + c = 0` where a = 7, b = 2 and c = 7. 
So given equation is already in quadratic form, which is 7x^2 + 2x + 7 = 0.\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"input\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 18363,\n        \"samples\": [\n          \"Document: This document describes the challenges faced by small business owners in the digital age, and how they can respond to these challenges in order to remain successful.\",\n          \"How to reduce the number of plastic bags used in the grocery store?\",\n          \"Suggestion: A cat obsessed with astronomy\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"instruction\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 51725,\n        \"samples\": [\n          \"Calculate the standard deviation for the following numbers.\",\n          \"Automatically insert a semicolon at the end of this line of code\",\n          \"Determine how many seconds are in a day.\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}",
              "type": "dataframe",
              "variable_name": "df_train"
            },
            "application/vnd.google.colaboratory.module+javascript": "\n      import \"https://ssl.gstatic.com/colaboratory/data_table/e523c247d1e24a05/data_table.js\";\n\n      const table = window.createDataTable({\n        data: [[{\n            'v': 0,\n            'f': \"0\",\n        },\n\"1. Eat a balanced and nutritious diet: Make sure your meals are inclusive of a variety of fruits and vegetables, lean protein, whole grains, and healthy fats. This helps to provide your body with the essential nutrients to function at its best and can help prevent chronic diseases.\\n\\n2. Engage in regular physical activity: Exercise is crucial for maintaining strong bones, muscles, and cardiovascular health. Aim for at least 150 minutes of moderate aerobic exercise or 75 minutes of vigorous exercise each week.\\n\\n3. Get enough sleep: Getting enough quality sleep is crucial for physical and mental well-being. It helps to regulate mood, improve cognitive function, and supports healthy growth and immune function. Aim for 7-9 hours of sleep each night.\",\n\"\",\n\"Give three tips for staying healthy.\"],\n [{\n            'v': 1,\n            'f': \"1\",\n        },\n\"The three primary colors are red, blue, and yellow. These colors are called primary because they cannot be created by mixing other colors and all other colors can be made by combining them in various proportions. In the additive color system, used for light, the primary colors are red, green, and blue (RGB).\",\n\"\",\n\"What are the three primary colors?\"],\n [{\n            'v': 2,\n            'f': \"2\",\n        },\n\"An atom is the basic building block of all matter and is made up of three types of particles: protons, neutrons, and electrons. The structure of an atom can be described as a nucleus at the center surrounded by a cloud of electrons.\\n\\nThe nucleus of an atom is made up of protons and neutrons. Protons are positively charged particles and neutrons are neutral particles with no charge. 
Both of these particles are located in the nucleus of the atom, which is at the center of the atom and contains most of the atom's mass.\\n\\nSurrounding the nucleus of the atom is a cloud of electrons. Electrons are negatively charged particles that are in constant motion around the nucleus. The electron cloud is divided into shells or orbitals, and each shell can hold a certain number of electrons. The number of electrons in the outermost shell, called the valence shell, determines the chemical properties of the atom. \\n\\nIn a neutral atom, the number of protons in the nucleus is equal to the number of electrons in the electron cloud, so the positive and negative charges balance out and the atom has no overall charge. The number of protons, also called the atomic number, determines what element the atom is.\",\n\"\",\n\"Describe the structure of an atom.\"],\n [{\n            'v': 3,\n            'f': \"3\",\n        },\n\"There are several ways to reduce air pollution, including:\\n\\n1. Reduce energy consumption: By conserving energy, we reduce the amount of pollution emitted from power plants. Some ways to achieve this include: using energy-efficient appliances and lighting, insulating homes, and shutting off lights and electronics when not in use.\\n\\n2. Use public transportation or carpool: Cars and trucks are major sources of pollution. By using public transportation, carpooling or biking, you can reduce the number of cars on the road, which in turn reduces emissions.\\n\\n3. Be responsible with waste: When waste is burned or decomposed, they can release harmful gases into the air. Dispose of hazardous waste correctly and recycle materials like plastics, paper, and glass when possible.\\n\\n4. Support clean energy sources: Renewable energy sources like solar and wind power generate little or no pollution. Support clean energy companies and projects, and consider installing renewable energy systems in your home.\\n\\n5. 
Drive efficiently: If you need to drive, make sure your vehicle is well maintained and drive efficiently by avoiding rapid acceleration and braking, and reducing idling.\\n\\n6. Be mindful of product choices: Some products and practices release more pollution than others. Choose products with eco-friendly packaging and try to minimize your use of aerosol sprays, which can release harmful chemicals into the air.\",\n\"\",\n\"How can we reduce air pollution?\"],\n [{\n            'v': 4,\n            'f': \"4\",\n        },\n\"I had to make a difficult decision when I was working as a project manager at a construction company. I was in charge of a project that needed to be completed by a certain date in order to meet the client\\u2019s expectations. However, due to unexpected delays, we were not able to meet the deadline and so I had to make a difficult decision. I decided to extend the deadline, but I had to stretch the team\\u2019s resources even further and increase the budget. Although it was a risky decision, I ultimately decided to go ahead with it to ensure that the project was completed on time and that the client\\u2019s expectations were met. The project was eventually successfully completed and this was seen as a testament to my leadership and decision-making abilities.\",\n\"\",\n\"Pretend you are a project manager of a construction company. 
Describe a time when you had to make a difficult decision.\"]],\n        columns: [[\"number\", \"index\"], [\"string\", \"output\"], [\"string\", \"input\"], [\"string\", \"instruction\"]],\n        columnOptions: [{\"width\": \"1px\", \"className\": \"index_column\"}],\n        rowsPerPage: 25,\n        helpUrl: \"https://colab.research.google.com/notebooks/data_table.ipynb\",\n        suppressOutputScrolling: true,\n        minimumWidth: undefined,\n      });\n\n      function appendQuickchartButton(parentElement) {\n        let quickchartButtonContainerElement = document.createElement('div');\n        quickchartButtonContainerElement.innerHTML = `\n<div id=\"df-06abcf28-7af1-4949-ae75-c513528cf6ae\">\n  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-06abcf28-7af1-4949-ae75-c513528cf6ae')\"\n            title=\"Suggest charts\"\n            style=\"display:none;\">\n    \n<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n     width=\"24px\">\n    <g>\n        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n    </g>\n</svg>\n  </button>\n  \n<style>\n  .colab-df-quickchart {\n      --bg-color: #E8F0FE;\n      --fill-color: #1967D2;\n      --hover-bg-color: #E2EBFA;\n      --hover-fill-color: #174EA6;\n      --disabled-fill-color: #AAA;\n      --disabled-bg-color: #DDD;\n  }\n\n  [theme=dark] .colab-df-quickchart {\n      --bg-color: #3B4455;\n      --fill-color: #D2E3FC;\n      --hover-bg-color: #434B5C;\n      --hover-fill-color: #FFFFFF;\n      --disabled-bg-color: #3B4455;\n      --disabled-fill-color: #666;\n  }\n\n  .colab-df-quickchart {\n    background-color: var(--bg-color);\n    border: none;\n    border-radius: 50%;\n    cursor: pointer;\n    display: none;\n    fill: var(--fill-color);\n    height: 32px;\n    padding: 0;\n    width: 32px;\n  }\n\n  .colab-df-quickchart:hover {\n    background-color: 
var(--hover-bg-color);\n    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n    fill: var(--button-hover-fill-color);\n  }\n\n  .colab-df-quickchart-complete:disabled,\n  .colab-df-quickchart-complete:disabled:hover {\n    background-color: var(--disabled-bg-color);\n    fill: var(--disabled-fill-color);\n    box-shadow: none;\n  }\n\n  .colab-df-spinner {\n    border: 2px solid var(--fill-color);\n    border-color: transparent;\n    border-bottom-color: var(--fill-color);\n    animation:\n      spin 1s steps(1) infinite;\n  }\n\n  @keyframes spin {\n    0% {\n      border-color: transparent;\n      border-bottom-color: var(--fill-color);\n      border-left-color: var(--fill-color);\n    }\n    20% {\n      border-color: transparent;\n      border-left-color: var(--fill-color);\n      border-top-color: var(--fill-color);\n    }\n    30% {\n      border-color: transparent;\n      border-left-color: var(--fill-color);\n      border-top-color: var(--fill-color);\n      border-right-color: var(--fill-color);\n    }\n    40% {\n      border-color: transparent;\n      border-right-color: var(--fill-color);\n      border-top-color: var(--fill-color);\n    }\n    60% {\n      border-color: transparent;\n      border-right-color: var(--fill-color);\n    }\n    80% {\n      border-color: transparent;\n      border-right-color: var(--fill-color);\n      border-bottom-color: var(--fill-color);\n    }\n    90% {\n      border-color: transparent;\n      border-bottom-color: var(--fill-color);\n    }\n  }\n</style>\n\n  <script>\n    async function quickchart(key) {\n      const quickchartButtonEl =\n        document.querySelector('#' + key + ' button');\n      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n      quickchartButtonEl.classList.add('colab-df-spinner');\n      try {\n        const charts = await google.colab.kernel.invokeFunction(\n            'suggestCharts', [key], {});\n      } catch (error) {\n        
console.error('Error during call to suggestCharts:', error);\n      }\n      quickchartButtonEl.classList.remove('colab-df-spinner');\n      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n    }\n    (() => {\n      let quickchartButtonEl =\n        document.querySelector('#df-06abcf28-7af1-4949-ae75-c513528cf6ae button');\n      quickchartButtonEl.style.display =\n        google.colab.kernel.accessAllowed ? 'block' : 'none';\n    })();\n  </script>\n</div>`;\n        parentElement.appendChild(quickchartButtonContainerElement);\n      }\n\n      appendQuickchartButton(table);\n    ",
            "text/html": [
              "\n",
              "  <div id=\"df-e0c9e8a7-cf25-44da-af8b-31c98cac8e16\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>output</th>\n",
              "      <th>input</th>\n",
              "      <th>instruction</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>1. Eat a balanced and nutritious diet: Make su...</td>\n",
              "      <td></td>\n",
              "      <td>Give three tips for staying healthy.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>The three primary colors are red, blue, and ye...</td>\n",
              "      <td></td>\n",
              "      <td>What are the three primary colors?</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>An atom is the basic building block of all mat...</td>\n",
              "      <td></td>\n",
              "      <td>Describe the structure of an atom.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>There are several ways to reduce air pollution...</td>\n",
              "      <td></td>\n",
              "      <td>How can we reduce air pollution?</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>I had to make a difficult decision when I was ...</td>\n",
              "      <td></td>\n",
              "      <td>Pretend you are a project manager of a constru...</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-e0c9e8a7-cf25-44da-af8b-31c98cac8e16')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-e0c9e8a7-cf25-44da-af8b-31c98cac8e16 button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-e0c9e8a7-cf25-44da-af8b-31c98cac8e16');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "<div id=\"df-4e545d4d-b5c5-4d2e-b0c9-4511c9049601\">\n",
              "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-4e545d4d-b5c5-4d2e-b0c9-4511c9049601')\"\n",
              "            title=\"Suggest charts\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "  </button>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "      --bg-color: #E8F0FE;\n",
              "      --fill-color: #1967D2;\n",
              "      --hover-bg-color: #E2EBFA;\n",
              "      --hover-fill-color: #174EA6;\n",
              "      --disabled-fill-color: #AAA;\n",
              "      --disabled-bg-color: #DDD;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "      --bg-color: #3B4455;\n",
              "      --fill-color: #D2E3FC;\n",
              "      --hover-bg-color: #434B5C;\n",
              "      --hover-fill-color: #FFFFFF;\n",
              "      --disabled-bg-color: #3B4455;\n",
              "      --disabled-fill-color: #666;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart {\n",
              "    background-color: var(--bg-color);\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: var(--fill-color);\n",
              "    height: 32px;\n",
              "    padding: 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: var(--hover-bg-color);\n",
              "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: var(--button-hover-fill-color);\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart-complete:disabled,\n",
              "  .colab-df-quickchart-complete:disabled:hover {\n",
              "    background-color: var(--disabled-bg-color);\n",
              "    fill: var(--disabled-fill-color);\n",
              "    box-shadow: none;\n",
              "  }\n",
              "\n",
              "  .colab-df-spinner {\n",
              "    border: 2px solid var(--fill-color);\n",
              "    border-color: transparent;\n",
              "    border-bottom-color: var(--fill-color);\n",
              "    animation:\n",
              "      spin 1s steps(1) infinite;\n",
              "  }\n",
              "\n",
              "  @keyframes spin {\n",
              "    0% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "      border-left-color: var(--fill-color);\n",
              "    }\n",
              "    20% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    30% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    40% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    60% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    80% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "    90% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "  }\n",
              "</style>\n",
              "\n",
              "  <script>\n",
              "    async function quickchart(key) {\n",
              "      const quickchartButtonEl =\n",
              "        document.querySelector('#' + key + ' button');\n",
              "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
              "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
              "      try {\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      } catch (error) {\n",
              "        console.error('Error during call to suggestCharts:', error);\n",
              "      }\n",
              "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
              "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
              "    }\n",
              "    (() => {\n",
              "      let quickchartButtonEl =\n",
              "        document.querySelector('#df-4e545d4d-b5c5-4d2e-b0c9-4511c9049601 button');\n",
              "      quickchartButtonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "    })();\n",
              "  </script>\n",
              "</div>\n",
              "\n",
              "    </div>\n",
              "  </div>\n"
            ],
            "text/plain": [
              "                                              output input  \\\n",
              "0  1. Eat a balanced and nutritious diet: Make su...         \n",
              "1  The three primary colors are red, blue, and ye...         \n",
              "2  An atom is the basic building block of all mat...         \n",
              "3  There are several ways to reduce air pollution...         \n",
              "4  I had to make a difficult decision when I was ...         \n",
              "\n",
              "                                         instruction  \n",
              "0               Give three tips for staying healthy.  \n",
              "1                 What are the three primary colors?  \n",
              "2                 Describe the structure of an atom.  \n",
              "3                   How can we reduce air pollution?  \n",
              "4  Pretend you are a project manager of a constru...  "
            ]
          },
          "execution_count": 7,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "from google.colab import data_table\n",
        "import pandas as pd\n",
        "\n",
        "# Enable interactive DataFrame display\n",
        "data_table.enable_dataframe_formatter()\n",
        "\n",
        "# Convert the 'train' split to a Pandas DataFrame\n",
        "df_train = pd.DataFrame(dataset)\n",
        "\n",
        "# Select the first 5 rows\n",
        "df_train.head(5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MAGUvIjReACW"
      },
      "source": [
        "### Set the prompt template\n",
        "\n",
        "Here you define the Alpaca prompt template, which has three sections:\n",
        "- Instruction\n",
        "- Input\n",
        "- Response"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "LjY75GoYUCB8"
      },
      "outputs": [],
      "source": [
        "alpaca_prompt_template = \"\"\"Below is an instruction that describes a task, paired with an\n",
        "input that provides further context. Write a response that appropriately\n",
        "completes the request.\n",
        "\n",
        "### Instruction:\n",
        "{}\n",
        "\n",
        "### Input:\n",
        "{}\n",
        "\n",
        "### Response:\n",
        "{}\"\"\"\n",
        "\n",
        "EOS_TOKEN = tokenizer.eos_token  # Appended to every training example so the model learns when to stop."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RlLb28SSef_M"
      },
      "source": [
        "### Define the formatting function\n",
        "\n",
        "The formatting function applies the template defined above to each row in the dataset, appends the EOS token, and stores the result in a single `text` column that the trainer can consume."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "1UYqbBh5ed_e"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "cef957d88bd9482db3f624c69b747d69",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/51760 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "def formatting_prompts_func(examples):\n",
        "    instructions = examples[\"instruction\"]\n",
        "    inputs       = examples[\"input\"]\n",
        "    outputs      = examples[\"output\"]\n",
        "    # Append EOS_TOKEN so the model learns where a response ends;\n",
        "    # without it, generation may not stop.\n",
        "    texts = [\n",
        "        alpaca_prompt_template.format(instruction, input_text, output) + EOS_TOKEN\n",
        "        for instruction, input_text, output in zip(instructions, inputs, outputs)\n",
        "    ]\n",
        "    return {\"text\": texts}\n",
        "\n",
        "dataset = dataset.map(formatting_prompts_func, batched = True,)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SXd9bTZd1aaL"
      },
      "source": [
        "### Set LoRA configuration\n",
        "\n",
        "LoRA (Low-Rank Adaptation) makes fine-tuning efficient by freezing the pretrained weights and training only small low-rank matrices added to selected layers.\n",
        "\n",
        "Here, you set the following parameters:\n",
        "- `r` to 16, which controls the rank of the adaptation matrices.\n",
        "- `lora_alpha` to 16, which scales the LoRA update.\n",
        "- `lora_dropout` to 0, the value for which Unsloth's implementation is optimized.\n",
        "\n",
        "To learn more about LoRA parameters and their effects, check out the [LoRA parameters encyclopedia](https://github.com/unslothai/unsloth/wiki#lora-parameters-encyclopedia)."
      ]
    },
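    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of how `r` and `lora_alpha` interact, here is the update rule from the standard LoRA formulation. The symbols $W_0$, $A$, $B$, $d$, $k$, and $x$ are notation introduced here for illustration and do not appear elsewhere in this notebook.\n",
        "\n",
        "$$\n",
        "h = W_0 x + \\frac{\\alpha}{r} B A x, \\qquad B \\in \\mathbb{R}^{d \\times r}, \\quad A \\in \\mathbb{R}^{r \\times k}\n",
        "$$\n",
        "\n",
        "The pretrained weight $W_0 \\in \\mathbb{R}^{d \\times k}$ stays frozen and only $A$ and $B$ are trained, so each adapted layer adds $r(d + k)$ trainable parameters. With `r = 16` and `lora_alpha = 16`, the scaling factor $\\alpha / r$ equals 1. This assumes `use_rslora = False`; rank-stabilized LoRA scales by $\\alpha / \\sqrt{r}$ instead."
      ]
    },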
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "6bZsfBuZDeCL"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Unsloth 2024.11.11 patched 26 layers with 26 QKV layers, 26 O layers and 26 MLP layers.\n"
          ]
        }
      ],
      "source": [
        "model = FastLanguageModel.get_peft_model(\n",
        "    model,\n",
        "    r = 16, # LoRA attention dimension\n",
        "    target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
        "                      \"gate_proj\", \"up_proj\", \"down_proj\",],\n",
        "    lora_alpha = 16,  # Alpha parameter for LoRA scaling\n",
        "    lora_dropout = 0, # Supports any, but = 0 is optimized\n",
        "    bias = \"none\",    # Supports any, but = \"none\" is optimized\n",
        "    use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for very long context\n",
        "    random_state = 3407,\n",
        "    use_rslora = False,  # Rank stabilized LoRA\n",
        "    loftq_config = None, # LoRA-Fine-Tuning-Aware Quantization\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vITh0KVJ10qX"
      },
      "source": [
        "### Set training configuration\n",
        "\n",
        "Set up the training arguments that define how the model will be trained.\n",
        "\n",
        "Here, you'll define the following parameters:\n",
        "\n",
        "- For training and evaluation:\n",
        "  - `output directory`\n",
        "  - `max steps`\n",
        "  - `batch sizes`\n",
        "\n",
        "- To optimize the training process:\n",
        "  - `learning rate`\n",
        "  - `optimizer`\n",
        "  - `learning rate scheduler`\n",
        "\n",
        "**Note:** `max_steps` is set as 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "i__N744Jgz5X"
      },
      "outputs": [],
      "source": [
        "from transformers import TrainingArguments\n",
        "\n",
        "training_args = TrainingArguments(\n",
        "    per_device_train_batch_size = 2,\n",
        "    gradient_accumulation_steps = 4,\n",
        "    warmup_steps = 5,\n",
        "    # num_train_epochs = 1, # Set this for 1 full training run.\n",
        "    max_steps = 60,\n",
        "    learning_rate = 2e-4,\n",
        "    fp16 = not torch.cuda.is_bf16_supported(),\n",
        "    bf16 = torch.cuda.is_bf16_supported(),\n",
        "    logging_steps = 1,\n",
        "    optim = \"adamw_8bit\",\n",
        "    weight_decay = 0.01,\n",
        "    lr_scheduler_type = \"linear\",\n",
        "    seed = 42,\n",
        "    output_dir = \"outputs\",\n",
        "    report_to = \"none\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "idAEIeSQ3xdS"
      },
      "source": [
        "<a name=\"Train\"></a>\n",
        "### Train the model\n",
        "\n",
        "[Huggingface's TRL](https://huggingface.co/docs/trl/index) offers a user-friendly API for building SFT models and training them on your dataset with just a few lines of code. Here you will use Huggingface TRL's `SFTTrainer` class to train the model. This class inherits from the `Trainer` class available in the Transformers library, but is specifically optimized for supervised fine-tuning (instruction tuning). Read more about SFFTrainer from the [official TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "95_Nn-89DhsL"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "cb9944ca20b34cf6a7f33f8ef38142bc",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map (num_proc=2):   0%|          | 0/51760 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "max_steps is given, it will override any value given in num_train_epochs\n"
          ]
        }
      ],
      "source": [
        "from trl import SFTTrainer\n",
        "from transformers import TrainingArguments\n",
        "from unsloth import is_bfloat16_supported\n",
        "\n",
        "trainer = SFTTrainer(\n",
        "    model = model,\n",
        "    tokenizer = tokenizer,\n",
        "    train_dataset = dataset,\n",
        "    dataset_text_field = \"text\",\n",
        "    max_seq_length = max_seq_length,\n",
        "    dataset_num_proc = 2,\n",
        "    # Setting packing as False can speed up training five times\n",
        "    # for short sequences.\n",
        "    packing = False,\n",
        "    args = training_args\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "10U7GGxLkee0"
      },
      "source": [
        "Now, let's start the fine-tuning process by calling `trainer.train()`, which uses `SFTTrainer` to handle the training loop, including data loading, forward and backward passes, and optimizer steps, all configured according to the settings you've provided."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "yqxqAZ7KJ4oL"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1\n",
            "   \\\\   /|    Num examples = 51,760 | Num Epochs = 1\n",
            "O^O/ \\_/ \\    Batch size per device = 2 | Gradient Accumulation steps = 4\n",
            "\\        /    Total batch size = 8 | Total steps = 60\n",
            " \"-____-\"     Number of trainable parameters = 20,766,720\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='60' max='60' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [60/60 03:30, Epoch 0/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>1.874800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>2.086800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>2.080000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>2.015800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>1.535900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>1.430500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>1.328700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>1.366300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>1.179300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>1.110300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>11</td>\n",
              "      <td>1.037000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>12</td>\n",
              "      <td>1.119800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>13</td>\n",
              "      <td>1.074500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>14</td>\n",
              "      <td>1.005300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>15</td>\n",
              "      <td>0.982400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>16</td>\n",
              "      <td>0.988500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>17</td>\n",
              "      <td>1.038500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>18</td>\n",
              "      <td>1.105900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>19</td>\n",
              "      <td>1.040700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>20</td>\n",
              "      <td>0.927900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>21</td>\n",
              "      <td>1.098500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>22</td>\n",
              "      <td>1.001200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>23</td>\n",
              "      <td>1.172400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>24</td>\n",
              "      <td>1.023400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>25</td>\n",
              "      <td>0.861500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>26</td>\n",
              "      <td>1.052300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>27</td>\n",
              "      <td>1.025700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>28</td>\n",
              "      <td>0.868600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>29</td>\n",
              "      <td>0.941500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>30</td>\n",
              "      <td>0.942400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>31</td>\n",
              "      <td>0.887800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>32</td>\n",
              "      <td>0.930300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>33</td>\n",
              "      <td>1.059300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>34</td>\n",
              "      <td>0.889700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>35</td>\n",
              "      <td>1.032900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>36</td>\n",
              "      <td>0.978600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>37</td>\n",
              "      <td>0.952000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>38</td>\n",
              "      <td>0.849100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>39</td>\n",
              "      <td>0.706200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>40</td>\n",
              "      <td>1.116000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>41</td>\n",
              "      <td>1.135700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>42</td>\n",
              "      <td>0.976800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>43</td>\n",
              "      <td>0.998700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>44</td>\n",
              "      <td>1.102600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>45</td>\n",
              "      <td>0.987900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>46</td>\n",
              "      <td>1.053700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>47</td>\n",
              "      <td>0.904800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>48</td>\n",
              "      <td>1.109300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>49</td>\n",
              "      <td>0.984200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>50</td>\n",
              "      <td>0.932200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>51</td>\n",
              "      <td>0.856400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>52</td>\n",
              "      <td>0.994300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>53</td>\n",
              "      <td>0.916800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>54</td>\n",
              "      <td>0.827000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>55</td>\n",
              "      <td>0.860700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>56</td>\n",
              "      <td>0.832600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>57</td>\n",
              "      <td>1.054400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>58</td>\n",
              "      <td>1.029600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>59</td>\n",
              "      <td>0.980800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>60</td>\n",
              "      <td>0.884100</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "trainer_stats = trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d_SJ-iTxlH-c"
      },
      "source": [
        "### Save the model locally\n",
        "\n",
        "After training is complete, save the fine-tuned model by calling `save_pretrained(new_model)`. This saves the model weights and configuration files to the directory specified by `new_model` (**gemma_ft_unsloth**). You can reload and use the fine-tuned model later for inference or further training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "fuLpAtt7lQdU"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "('gemma_ft_unsloth/tokenizer_config.json',\n",
              " 'gemma_ft_unsloth/special_tokens_map.json',\n",
              " 'gemma_ft_unsloth/tokenizer.model',\n",
              " 'gemma_ft_unsloth/added_tokens.json',\n",
              " 'gemma_ft_unsloth/tokenizer.json')"
            ]
          },
          "execution_count": 14,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "new_model = \"gemma_ft_unsloth\"\n",
        "model.save_pretrained(new_model)\n",
        "tokenizer.save_pretrained(new_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cz41b3rDNegs"
      },
      "source": [
        "### Push to Hugging Face Hub\n",
        "\n",
        "You can use the model's `push_to_hub` method to upload the model to Hugging Face Hub.\n",
        "\n",
        "**Note**: In the following code snippets, replace \"your_hf_username\" to your Hugging Face username."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H4d-nreYNbYb"
      },
      "outputs": [],
      "source": [
        "# Push the trained model to Hub\n",
        "model.push_to_hub(\"your_hf_username/gemma_ft_unsloth\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SqQFyg00XyjO"
      },
      "source": [
        "If you only want to save the Lora adapters, uncomment and run the code below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "-fB6XaUbYJFo"
      },
      "outputs": [],
      "source": [
        "#model.push_to_hub_merged(\"your_hf_username/gemma_ft_lora\", tokenizer,\n",
        "#                         save_method = \"lora\", token = os.environ[\"HF_TOKEN\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xd_V_x6fYN8Y"
      },
      "source": [
        "If you want to save the model as GGUF, uncomment and run the code below. To check the different quantization methods supported by Unsloth, visit\n",
        "[Unsloth wiki's Save as GGUF section](https://github.com/unslothai/unsloth/wiki#saving-to-gguf)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "VzU1I78wYts_"
      },
      "outputs": [],
      "source": [
        "#model.push_to_hub_gguf(\"your_hf_username/gemma_ft_q4_k_m\",\n",
        "#                       tokenizer, quantization_method = \"q4_k_m\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xfyWrHUYY3AP"
      },
      "source": [
        "To know more about the different formats available in Unsloth, check out [Unsloth wiki's saving models section](https://github.com/unslothai/unsloth/wiki#saving-models-to-16bit-for-vllm)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zM-QLe32qzlF"
      },
      "source": [
        "## Inference\n",
        "\n",
        "### Prompt using the newly fine-tuned model\n",
        "\n",
        "\n",
        "Now that you've finally fine-tuned your custom Gemma model, let's reload the LoRA adapter weights and tokenizer to finally prompt using it and also verify if it's really working as intended."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "vzKkxC7arjxP"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "==((====))==  Unsloth 2024.11.11: Fast Gemma2 patching. Transformers:4.46.2.\n",
            "   \\\\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform: Linux.\n",
            "O^O/ \\_/ \\    Torch: 2.5.1+cu121. CUDA: 7.5. CUDA Toolkit: 12.1. Triton: 3.1.0\n",
            "\\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post3. FA2 = False]\n",
            " \"-____-\"     Free Apache license: http://github.com/unslothai/unsloth\n",
            "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "PeftModelForCausalLM(\n",
              "  (base_model): LoraModel(\n",
              "    (model): Gemma2ForCausalLM(\n",
              "      (model): Gemma2Model(\n",
              "        (embed_tokens): Embedding(256000, 2304, padding_idx=0)\n",
              "        (layers): ModuleList(\n",
              "          (0-25): 26 x Gemma2DecoderLayer(\n",
              "            (self_attn): Gemma2Attention(\n",
              "              (q_proj): lora.Linear4bit(\n",
              "                (base_layer): Linear4bit(in_features=2304, out_features=2048, bias=False)\n",
              "                (lora_dropout): ModuleDict(\n",
              "                  (default): Identity()\n",
              "                )\n",
              "                (lora_A): ModuleDict(\n",
              "                  (default): Linear(in_features=2304, out_features=16, bias=False)\n",
              "                )\n",
              "                (lora_B): ModuleDict(\n",
              "                  (default): Linear(in_features=16, out_features=2048, bias=False)\n",
              "                )\n",
              "                (lora_embedding_A): ParameterDict()\n",
              "                (lora_embedding_B): ParameterDict()\n",
              "                (lora_magnitude_vector): ModuleDict()\n",
              "              )\n",
              "              (k_proj): lora.Linear4bit(\n",
              "                (base_layer): Linear4bit(in_features=2304, out_features=1024, bias=False)\n",
              "                (lora_dropout): ModuleDict(\n",
              "                  (default): Identity()\n",
              "                )\n",
              "                (lora_A): ModuleDict(\n",
              "                  (default): Linear(in_features=2304, out_features=16, bias=False)\n",
              "                )\n",
              "                (lora_B): ModuleDict(\n",
              "                  (default): Linear(in_features=16, out_features=1024, bias=False)\n",
              "                )\n",
              "                (lora_embedding_A): ParameterDict()\n",
              "                (lora_embedding_B): ParameterDict()\n",
              "                (lora_magnitude_vector): ModuleDict()\n",
              "              )\n",
              "              (v_proj): lora.Linear4bit(\n",
              "                (base_layer): Linear4bit(in_features=2304, out_features=1024, bias=False)\n",
              "                (lora_dropout): ModuleDict(\n",
              "                  (default): Identity()\n",
              "                )\n",
              "                (lora_A): ModuleDict(\n",
              "                  (default): Linear(in_features=2304, out_features=16, bias=False)\n",
              "                )\n",
              "                (lora_B): ModuleDict(\n",
              "                  (default): Linear(in_features=16, out_features=1024, bias=False)\n",
              "                )\n",
              "                (lora_embedding_A): ParameterDict()\n",
              "                (lora_embedding_B): ParameterDict()\n",
              "                (lora_magnitude_vector): ModuleDict()\n",
              "              )\n",
              "              (o_proj): lora.Linear4bit(\n",
              "                (base_layer): Linear4bit(in_features=2048, out_features=2304, bias=False)\n",
              "                (lora_dropout): ModuleDict(\n",
              "                  (default): Identity()\n",
              "                )\n",
              "                (lora_A): ModuleDict(\n",
              "                  (default): Linear(in_features=2048, out_features=16, bias=False)\n",
              "                )\n",
              "                (lora_B): ModuleDict(\n",
              "                  (default): Linear(in_features=16, out_features=2304, bias=False)\n",
              "                )\n",
              "                (lora_embedding_A): ParameterDict()\n",
              "                (lora_embedding_B): ParameterDict()\n",
              "                (lora_magnitude_vector): ModuleDict()\n",
              "              )\n",
              "              (rotary_emb): GemmaFixedRotaryEmbedding()\n",
              "            )\n",
              "            (mlp): Gemma2MLP(\n",
              "              (gate_proj): lora.Linear4bit(\n",
              "                (base_layer): Linear4bit(in_features=2304, out_features=9216, bias=False)\n",
              "                (lora_dropout): ModuleDict(\n",
              "                  (default): Identity()\n",
              "                )\n",
              "                (lora_A): ModuleDict(\n",
              "                  (default): Linear(in_features=2304, out_features=16, bias=False)\n",
              "                )\n",
              "                (lora_B): ModuleDict(\n",
              "                  (default): Linear(in_features=16, out_features=9216, bias=False)\n",
              "                )\n",
              "                (lora_embedding_A): ParameterDict()\n",
              "                (lora_embedding_B): ParameterDict()\n",
              "                (lora_magnitude_vector): ModuleDict()\n",
              "              )\n",
              "              (up_proj): lora.Linear4bit(\n",
              "                (base_layer): Linear4bit(in_features=2304, out_features=9216, bias=False)\n",
              "                (lora_dropout): ModuleDict(\n",
              "                  (default): Identity()\n",
              "                )\n",
              "                (lora_A): ModuleDict(\n",
              "                  (default): Linear(in_features=2304, out_features=16, bias=False)\n",
              "                )\n",
              "                (lora_B): ModuleDict(\n",
              "                  (default): Linear(in_features=16, out_features=9216, bias=False)\n",
              "                )\n",
              "                (lora_embedding_A): ParameterDict()\n",
              "                (lora_embedding_B): ParameterDict()\n",
              "                (lora_magnitude_vector): ModuleDict()\n",
              "              )\n",
              "              (down_proj): lora.Linear4bit(\n",
              "                (base_layer): Linear4bit(in_features=9216, out_features=2304, bias=False)\n",
              "                (lora_dropout): ModuleDict(\n",
              "                  (default): Identity()\n",
              "                )\n",
              "                (lora_A): ModuleDict(\n",
              "                  (default): Linear(in_features=9216, out_features=16, bias=False)\n",
              "                )\n",
              "                (lora_B): ModuleDict(\n",
              "                  (default): Linear(in_features=16, out_features=2304, bias=False)\n",
              "                )\n",
              "                (lora_embedding_A): ParameterDict()\n",
              "                (lora_embedding_B): ParameterDict()\n",
              "                (lora_magnitude_vector): ModuleDict()\n",
              "              )\n",
              "              (act_fn): PytorchGELUTanh()\n",
              "            )\n",
              "            (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "            (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "            (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "            (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "          )\n",
              "        )\n",
              "        (norm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "      )\n",
              "      (lm_head): Linear(in_features=2304, out_features=256000, bias=False)\n",
              "    )\n",
              "  )\n",
              ")"
            ]
          },
          "execution_count": 18,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "model, tokenizer = FastLanguageModel.from_pretrained(\n",
        "    model_name = new_model, # Your finetuned model name\n",
        "    max_seq_length = max_seq_length,\n",
        "    dtype = None,\n",
        "    load_in_4bit = True,\n",
        ")\n",
        "FastLanguageModel.for_inference(model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3FU-E2kHrfo8"
      },
      "source": [
        "Now, test the fine-tuned model with a sample prompt by first using the tokenizer to generate the input ids, and then relying on the reloaded fine-tuned model to generate a response using `model.generate()`.\n",
        "\n",
        "Instead of waiting the entire time, you can view the generation token by token by using a `TextStreamer` for continuous inference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "FQybHt5YnG9W"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<bos>Below is an instruction that describes a task, paired with an\n",
            "input that provides further context. Write a response that appropriately\n",
            "completes the request.\n",
            "\n",
            "### Instruction:\n",
            "Continue the fibonnaci sequence.\n",
            "\n",
            "### Input:\n",
            "1, 1, 2, 3, 5, 8\n",
            "\n",
            "### Response:\n",
            "The next number in the Fibonacci sequence is 13.<eos>\n"
          ]
        }
      ],
      "source": [
        "from transformers import TextStreamer\n",
        "\n",
        "inputs = tokenizer(\n",
        "[\n",
        "    alpaca_prompt_template.format(\n",
        "        \"Continue the fibonnaci sequence.\", # instruction\n",
        "        \"1, 1, 2, 3, 5, 8\", # input\n",
        "        \"\", # output - leave this blank for generation!\n",
        "    )\n",
        "], return_tensors = \"pt\").to(\"cuda\")\n",
        "\n",
        "text_streamer = TextStreamer(tokenizer)\n",
        "_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0k7OvHwLtQz0"
      },
      "source": [
        "Congratulations! You've successfully fine-tuned Gemma using Unsloth. You've covered the entire process, from setting up the environment to training and testing the model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vbJAeaK0tqMo"
      },
      "source": [
        "## What's next?\n",
        "Your next steps could include the following:\n",
        "\n",
        "- **Experiment with Different Datasets**: Try fine-tuning on other datasets in [Hugging Face](https://huggingface.co/docs/datasets/en/index) or your own data to adapt the model to various tasks or domains.\n",
        "\n",
        "- **Tune Hyperparameters**: Adjust training parameters (e.g., learning rate, batch size, epochs, LoRA settings) to optimize performance and improve training efficiency.\n",
        "\n",
        "- **Save model in different formats**: Check the different saving formats provided by Unsloth and use it in `llama.cpp` or a UI based system like `GPT4All`.\n",
        "\n",
        "By exploring these activities, you'll deepen your understanding and further enhance your fine-tuned Gemma model. Happy experimenting!"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[Gemma_2]Finetune_with_Unsloth.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
